{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-09-04T08:58:55.312756Z",
     "start_time": "2024-09-04T08:58:53.761807Z"
    }
   },
   "source": [
    "import os\n",
    "import torch\n",
    "from d2l import torch as d2l"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 1. 下载和预处理数据集",
   "id": "96130e28d43d5725"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-04T09:03:15.074535Z",
     "start_time": "2024-09-04T09:03:15.039824Z"
    }
   },
   "cell_type": "code",
   "source": [
    "data_dir = './fra-eng_dataset'\n",
    "\n",
    "\n",
    "def read_data_nmt():\n",
    "    \"\"\"载入“英语－法语”数据集\"\"\"\n",
    "    with open(os.path.join(data_dir, 'fra.txt'), 'r', encoding='utf-8') as f:\n",
    "        return f.read()\n",
    "\n",
    "\n",
    "raw_text = read_data_nmt()\n",
    "print(raw_text[:75])"
   ],
   "id": "4c99b086f7f6f056",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Go.\tVa !\n",
      "Hi.\tSalut !\n",
      "Run!\tCours !\n",
      "Run!\tCourez !\n",
      "Who?\tQui ?\n",
      "Wow!\tÇa alors !\n",
      "\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-04T09:08:33.950225Z",
     "start_time": "2024-09-04T09:08:31.939969Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 我们用空格代替不间断空格（non-breaking space）， 使用小写字母替换大写字母，并在单词和标点符号之间插入空格。\n",
    "def preprocess_nmt(text):\n",
    "    \"\"\"预处理“英语－法语”数据集\"\"\"\n",
    "\n",
    "    def no_space(char, prev_char):\n",
    "        return char in set(',.!?') and prev_char != ' '\n",
    "\n",
    "    # 使用空格替换不间断空格\n",
    "    # 使用小写字母替换大写字母\n",
    "    text = text.replace('\\u202f', ' ').replace('\\xa0', ' ').lower()\n",
    "    # 在单词和标点符号之间插入空格\n",
    "    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text)]\n",
    "    return ''.join(out)\n",
    "\n",
    "\n",
    "text = preprocess_nmt(raw_text)\n",
    "print(text[:80])"
   ],
   "id": "711a41429308ca72",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "go .\tva !\n",
      "hi .\tsalut !\n",
      "run !\tcours !\n",
      "run !\tcourez !\n",
      "who ?\tqui ?\n",
      "wow !\tça alors !\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 2. 词元化",
   "id": "785adc5f6e968b5a"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-04T09:10:19.038733Z",
     "start_time": "2024-09-04T09:10:18.460246Z"
    }
   },
   "cell_type": "code",
   "source": [
    "#@save\n",
    "def tokenize_nmt(text, num_examples=None):\n",
    "    \"\"\"词元化“英语－法语”数据数据集\"\"\"\n",
    "    source, target = [], []\n",
    "    for i, line in enumerate(text.split('\\n')):\n",
    "        if num_examples and i > num_examples:\n",
    "            break\n",
    "        parts = line.split('\\t')\n",
    "        if len(parts) == 2:\n",
    "            source.append(parts[0].split(' '))\n",
    "            target.append(parts[1].split(' '))\n",
    "    return source, target\n",
    "\n",
    "\n",
    "source, target = tokenize_nmt(text)\n",
    "source[:6], target[:6]"
   ],
   "id": "101d72f3352b68fe",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([['go', '.'],\n",
       "  ['hi', '.'],\n",
       "  ['run', '!'],\n",
       "  ['run', '!'],\n",
       "  ['who', '?'],\n",
       "  ['wow', '!']],\n",
       " [['va', '!'],\n",
       "  ['salut', '!'],\n",
       "  ['cours', '!'],\n",
       "  ['courez', '!'],\n",
       "  ['qui', '?'],\n",
       "  ['ça', 'alors', '!']])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 3. 词汇表",
   "id": "fd576f8eee62cb7c"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-04T09:13:38.866921Z",
     "start_time": "2024-09-04T09:13:38.691438Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 构造词汇表，对于出现频率 < min_freg(2) 的次丢弃\n",
    "src_vocab = d2l.Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])\n",
    "len(src_vocab)"
   ],
   "id": "5b6576646dbeee98",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10012"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 5
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 4. 加载数据集 ",
   "id": "26049b5b29678d9d"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-04T09:16:46.832201Z",
     "start_time": "2024-09-04T09:16:46.821314Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 设定相同的长度num_steps，如果文本序列的词元数目少于num_steps时，在其末尾添加特定的“<pad>”词元，直到长度达到num_steps\n",
    "# 反之，截断文本序列，只取其前num_steps个词元，并且丢弃剩余的词元\n",
    "def truncate_pad(line, num_steps, padding_token):\n",
    "    \"\"\"截断或填充文本序列\"\"\"\n",
    "    if len(line) > num_steps:\n",
    "        return line[:num_steps]  # 截断\n",
    "    return line + [padding_token] * (num_steps - len(line))  # 填充\n",
    "\n",
    "\n",
    "truncate_pad(src_vocab[source[0]], 10, src_vocab['<pad>'])\n",
    "# 这里长度只有2，前两个对应标号[47, 4]，其他用标号[1]即'<pad>'填充"
   ],
   "id": "7d2533b663921c42",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[47, 4, 1, 1, 1, 1, 1, 1, 1, 1]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-04T09:20:31.620711Z",
     "start_time": "2024-09-04T09:20:31.607322Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def build_array_nmt(lines, vocab, num_steps):\n",
    "    \"\"\"将机器翻译的文本序列转换成长度为num_steps的新句子\"\"\"\n",
    "    lines = [vocab[l] for l in lines]\n",
    "    lines = [l + [vocab['<eos>']] for l in lines]  # <eos> end of sentence 标记句子结尾\n",
    "    array = torch.tensor([truncate_pad(l, num_steps, vocab['<pad>']) for l in lines])\n",
    "    valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)  # 保存每句话实际有效的长度，排除填充的<pad>\n",
    "    return array, valid_len"
   ],
   "id": "7a3317297395431b",
   "outputs": [],
   "execution_count": 7
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 5. [封装]返回数据集的迭代器和词汇表",
   "id": "3732667ffb6faacf"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-04T09:28:51.392757Z",
     "start_time": "2024-09-04T09:28:51.382849Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def load_data_nmt(batch_size, num_steps, num_examples=600):\n",
    "    \"\"\"返回翻译数据集的迭代器和词表\"\"\"\n",
    "    # 1. 下载数据集并预处理(针对标点符号)\n",
    "    text = preprocess_nmt(read_data_nmt())\n",
    "\n",
    "    # 2. 词元化，返回原始数据和目标数据\n",
    "    source, target = tokenize_nmt(text, num_examples)\n",
    "\n",
    "    # 3. 词汇表\n",
    "    src_vocab = d2l.Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])\n",
    "    tgt_vocab = d2l.Vocab(target, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])\n",
    "\n",
    "    # 4. 将机器翻译的文本序列转换成长度为num_steps的新句子\n",
    "    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)\n",
    "    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)\n",
    "\n",
    "    # 5. 构造数据迭代器\n",
    "    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)\n",
    "    data_iter = d2l.load_array(data_arrays, batch_size)\n",
    "\n",
    "    return data_iter, src_vocab, tgt_vocab"
   ],
   "id": "50b4826bad7b61a4",
   "outputs": [],
   "execution_count": 8
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 6. 测试",
   "id": "3b19212b0e3517c"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-04T09:30:23.537839Z",
     "start_time": "2024-09-04T09:30:21.386679Z"
    }
   },
   "cell_type": "code",
   "source": [
    "train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size=2, num_steps=8)\n",
    "for X, X_valid_len, Y, Y_valid_len in train_iter:\n",
    "    print('X:', X.type(torch.int32))\n",
    "    print('X的有效长度:', X_valid_len)\n",
    "    print('Y:', Y.type(torch.int32))\n",
    "    print('Y的有效长度:', Y_valid_len)\n",
    "    break"
   ],
   "id": "b2d97a024432f15d",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X: tensor([[ 87,  22,   4,   3,   1,   1,   1,   1],\n",
      "        [  7, 101,   4,   3,   1,   1,   1,   1]], dtype=torch.int32)\n",
      "X的有效长度: tensor([4, 4])\n",
      "Y: tensor([[177, 178,  25,   4,   3,   1,   1,   1],\n",
      "        [  6,   7, 158,   4,   3,   1,   1,   1]], dtype=torch.int32)\n",
      "Y的有效长度: tensor([5, 5])\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-04T09:31:56.550909Z",
     "start_time": "2024-09-04T09:31:56.535312Z"
    }
   },
   "cell_type": "code",
   "source": "X.shape, X_valid_len.shape, Y.shape, Y_valid_len.shape",
   "id": "78204a88a80c5078",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([2, 8]), torch.Size([2]), torch.Size([2, 8]), torch.Size([2]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 10
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "`注意到`\n",
    "X: X是(2, 8), 2表示批量大小，即一次训练2个样本(句子)；8是时间步数，即句子长度"
   ],
   "id": "ecf8a2669fbaee04"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "例如：\n",
    "```\n",
    "X = tensor(\n",
    "[\n",
    "    [ 87,  22,   4,   3,   1,   1,   1,   1],\n",
    "    [  7, 101,   4,   3,   1,   1,   1,   1],\n",
    "], dtype=torch.int32)\n",
    "```\n",
    "表示：第一句话的标号为[87, 22, 4, 3, 1, 1, 1, 1]，其中后面4个为<pad>补充，并非句子实际长度，所以\n",
    "\n",
    "`X_valid_len = tensor([4, 4])`标记了第一句话有效长度只有 8 - 4 = 4"
   ],
   "id": "a69aa9e48bcdc0b5"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
